# coding=utf-8
import copy
import math
import tensorflow.keras.backend as K
from cv2 import cv2
import numpy as np
import numpy as np
import tensorflow as tf
from cv2 import cv2
from tensorflow.python.keras.layers import GlobalMaxPool2D, GlobalMaxPooling2D

from tensorflow.keras.layers import Conv2D, Conv2DTranspose,\
                                    GlobalAveragePooling2D, AveragePooling2D, MaxPool2D, UpSampling2D,\
                                    BatchNormalization, Activation, ReLU, Flatten, Dense, Input,\
                                    Add, Multiply, Concatenate, Softmax

# 获取图像属性
# h, w = img.shape[0:2]


# 调用均值滤波函数
def blur(img):
    """Return *img* smoothed with OpenCV's normalised 5x5 box (mean) filter."""
    kernel_size = (5, 5)  # width x height of the averaging window
    return cv2.blur(img, kernel_size)


# 中值滤波
def medianBlur(img):
    """Return *img* filtered with OpenCV's median filter using a 5x5 aperture."""
    # Median (not mean) filtering: each pixel becomes the median of its window.
    aperture = 5  # cv2.medianBlur requires an odd aperture greater than 1
    return cv2.medianBlur(img, aperture)


def spilt(a):
    """Return ``(start, stop)`` offsets of a length-*a* window around 0.

    ``range(start, stop)`` yields exactly *a* integer offsets. For odd *a*
    the window is centred; for even *a* it reaches one step further on the
    negative side (e.g. ``a=4`` -> offsets -2, -1, 0, 1).

    Args:
        a: window length (non-negative int).

    Returns:
        Tuple of ints ``(-(a // 2), a - a // 2)``.
    """
    # The old parity test `a / 2 == 0` was only true for a == 0, and in that
    # case it returned floats that range() rejects. One integer formula
    # covers every length.
    half = a // 2
    return -half, a - half


def original(i, j, k, a, b, img):
    """Collect the a x b neighbourhood of pixel (i, j) in channel k.

    Row offsets come from ``spilt(a)`` and column offsets from ``spilt(b)``.
    Any position that falls outside the image is filled with the value of
    the centre pixel ``img[i, j, k]``.

    Returns:
        1-D float array of length ``a * b``, filled row offset by row offset.
    """
    row_start, row_stop = spilt(a)
    col_start, col_stop = spilt(b)
    window = np.zeros(a * b)
    height, width = img.shape[0], img.shape[1]
    slot = 0
    for dr in range(row_start, row_stop):
        r = i + dr
        for dc in range(col_start, col_stop):
            c = j + dc
            inside = 0 <= r < height and 0 <= c < width
            window[slot] = img[r, c, k] if inside else img[i, j, k]
            slot += 1
    return window


# 最小值滤波
def min_functin(a, b, img):
    """Apply an a x b minimum filter to a 3-D image in place and save it.

    Each output pixel/channel is the minimum of its a x b window (built by
    ``original``), read from an unmodified copy of the input. The result is
    written to '去噪效果比较/min_functin_localEqualHist.jpg'.

    Args:
        a: window height.
        b: window width.
        img: image indexed as img[row, col, channel]; modified in place.

    Returns:
        The filtered image (the same object as *img*).
    """
    img0 = copy.copy(img)  # snapshot so every window sees the unfiltered pixels
    for i in range(img.shape[0]):
        # Start at column 0. The old range(2, ...) left the first two columns unfiltered.
        for j in range(img.shape[1]):
            for k in range(img.shape[2]):
                img[i, j, k] = np.min(original(i, j, k, a, b, img0))
    cv2.imwrite('去噪效果比较/min_functin_localEqualHist.jpg', img)
    return img


def max_functin(a, b, img):
    """Apply an a x b maximum filter to a 3-D image in place and save it.

    Each output pixel/channel is the maximum of its a x b window (built by
    ``original``), read from an unmodified copy of the input. The result is
    written to '去噪效果比较/max_functin_localEqualHist.jpg'.

    Args:
        a: window height.
        b: window width.
        img: image indexed as img[row, col, channel]; modified in place.

    Returns:
        The filtered image (the same object as *img*).
    """
    img0 = copy.copy(img)  # snapshot so every window sees the unfiltered pixels
    for i in range(img.shape[0]):
        # Start at column 0. The old range(2, ...) left the first two columns unfiltered.
        for j in range(img.shape[1]):
            for k in range(img.shape[2]):
                img[i, j, k] = np.max(original(i, j, k, a, b, img0))
    cv2.imwrite('去噪效果比较/max_functin_localEqualHist.jpg', img)
    return img


# 高斯滤波
def GaussianBlur(img):
    """Return *img* smoothed with a 5x5 Gaussian kernel (sigmaX = sigmaY = 1)."""
    # In the Python binding, the 4th positional argument of cv2.GaussianBlur
    # is `dst`, not `sigmaY`. The old call bound the 1 to `dst`. Pass both
    # sigmas by keyword so they land where intended.
    result = cv2.GaussianBlur(img, (5, 5), sigmaX=1, sigmaY=1)
    return result


# 非局部均值去噪
def fastNlMeansDenoising(img, h=25, templateWindowSize=11, searchWindowSize=21):
    """Denoise a colour image with OpenCV's fast non-local means filter.

    Args:
        img: image for cv2.fastNlMeansDenoisingColored (OpenCV expects an
            8-bit 3-channel image).
        h: filter strength, used for both the luminance (h) and colour
            (hColor) components. Large values remove noise thoroughly but
            also erase detail; small values keep detail but leave some noise.
        templateWindowSize: size in pixels of the patch used to compute
            weights; should be odd (OpenCV recommends 7).
        searchWindowSize: size in pixels of the window over which the
            weighted average is computed; should be odd (OpenCV recommends 21).

    Returns:
        The denoised image.
    """
    # The old code hard-coded templateWindowSize = 10, an even value. OpenCV's
    # implementation uses size // 2 as the half-width, so 10 effectively acted
    # as 11. Defaulting to 11 keeps that result while honouring the odd-size rule.
    return cv2.fastNlMeansDenoisingColored(img, None, h, h, templateWindowSize, searchWindowSize)


# 小波变换图像去噪，二维一级分解
def dwt2(img):
    """Decompose *img* with a one-level 2-D Haar DWT, reconstruct it and save it.

    NOTE: no wavelet coefficients are thresholded or otherwise modified, so
    the reconstruction equals the input up to floating-point error. This
    function does not remove noise by itself.

    Args:
        img: 2-D (H x W) or 3-D (H x W x C) image. The transform runs over
            the first two axes.

    Returns:
        The reconstructed image, cropped to img's height and width. It is
        also written to '去噪效果比较/dwt2.jpg'.
    """
    import pywt

    # axes=(0, 1): transform height and width explicitly. pywt's default
    # axes=(-2, -1) would transform width and channels of an H x W x C image.
    cA, (cH, cV, cD) = pywt.dwt2(img, 'haar', axes=(0, 1))

    # Reconstruct from the unmodified coefficients. Odd sizes come back one
    # pixel larger, so crop to the input's spatial size.
    rimg = pywt.idwt2((cA, (cH, cV, cD)), 'haar', axes=(0, 1))
    rimg = rimg[:img.shape[0], :img.shape[1]]
    cv2.imwrite('去噪效果比较/dwt2.jpg', rimg)
    return rimg


# 自适应直方图均衡化
def localEqualHist(image):
    """Apply CLAHE to the grayscale version of a BGR *image*, save and return it.

    The image is converted with COLOR_BGR2GRAY and equalised with
    clipLimit=5 on a 7x7 tile grid. The result is written to
    '去噪效果比较/localEqualHist.jpg'.

    Returns:
        The equalised single-channel image. It is returned (like the other
        filters here) so callers need not re-read the file.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=5, tileGridSize=(7, 7))
    dst = clahe.apply(gray)
    cv2.imwrite('去噪效果比较/localEqualHist.jpg', dst)
    return dst


# 方框滤波
def boxBlur(img):
    """Return *img* averaged over 3x3 neighbourhoods with cv2.boxFilter.

    ddepth=-1 keeps the output depth equal to the input's. boxFilter's
    default normalize=True makes this a mean filter.
    """
    window = (3, 3)
    return cv2.boxFilter(img, ddepth=-1, ksize=window)


# 双边滤波去噪
def bilateralBlur(img):
    """Return *img* smoothed by cv2.bilateralFilter.

    Uses a neighbourhood diameter of 9 pixels with sigmaColor = sigmaSpace = 80.
    """
    params = dict(d=9, sigmaColor=80, sigmaSpace=80)
    return cv2.bilateralFilter(img, **params)


# 模型预测
def inference_single_image(model, noisy_image):
    """Run *model* on one image and return the prediction without the batch axis.

    NOTE: a second ``inference_single_image(noisy_image, model)`` with the
    opposite parameter order is defined later in this module and replaces
    this one at import time. So that callers written for either order keep
    working, the arguments are swapped when *model* lacks a ``predict``
    method but *noisy_image* has one.

    Args:
        model: object with a Keras-style ``predict(batch)`` method.
        noisy_image: single image array; a leading batch axis of size 1 is added.

    Returns:
        ``model.predict(batch)[0]``.
    """
    if not hasattr(model, 'predict') and hasattr(noisy_image, 'predict'):
        model, noisy_image = noisy_image, model
    input_image = np.expand_dims(noisy_image, axis=0)
    predicted_image = model.predict(input_image)
    return predicted_image[0]


def redNet(img, model=None):
    """Denoise *img* with the saved REDNet model.

    Args:
        img: image array passed to the model unchanged; a batch axis is
            added internally.
            NOTE(review): no BGR->RGB conversion is applied; confirm the
            channel order the model was trained on.
        model: optional pre-loaded Keras model. When None (the default and
            the original behaviour), the .h5 file is loaded on every call,
            which is slow inside loops. Load it once and pass it in instead.

    Returns:
        The model's prediction for *img* with the batch axis removed.
    """
    if model is None:
        # Model path
        model = tf.keras.models.load_model('去噪效果比较/model/best_REDNet_blindnoise_256x256.h5')
    # Use keyword arguments: this module defines inference_single_image twice
    # with opposite positional orders. The later (noisy_image, model)
    # definition is the one in effect, so a positional (model, img) call
    # would swap the two.
    return inference_single_image(model=model, noisy_image=img)


from tensorflow.keras.layers import Conv2D, Conv2DTranspose, \
    GlobalAveragePooling2D, AveragePooling2D, MaxPool2D, UpSampling2D, \
    BatchNormalization, Activation, Flatten, Dense, Input, \
    Add, Multiply, Concatenate, concatenate, Softmax
from tensorflow.keras import initializers, regularizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.activations import softmax
from tensorflow.python.keras.models import load_model


class Convolutional_block(tf.keras.layers.Layer):
    """Four stacked 3x3, 64-filter, 'same'-padded convolutions, each followed by ReLU."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Keep these attribute names and their order; saved models may depend on them.
        self.conv_1 = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')
        self.conv_2 = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')
        self.conv_3 = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')
        self.conv_4 = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')

    def call(self, X):
        """Apply conv_1..conv_4 in sequence, each followed by a ReLU activation."""
        features = X
        for conv in (self.conv_1, self.conv_2, self.conv_3, self.conv_4):
            features = Activation('relu')(conv(features))
        return features


# class Channel_attention(tf.keras.layers.Layer):
    # def __init__(self, C=64, **kwargs):
      #   super().__init__(**kwargs)
       #  self.C = C
       #  self.gap = GlobalAveragePooling2D()
      #   self.dense_middle = Dense(units=2, activation='relu')
       #  self.dense_sigmoid = Dense(units=self.C, activation='sigmoid')

  #   def get_config(self):
       #  config = super().get_config().copy()
      #   config.update({
        #     'C': self.C
      #   })
       #  return config

  #   def call(self, X):
      #   v = self.gap(X)
        # print("ca_ after gap =",v.shape)
       #  fc1 = self.dense_middle(v)
        # print("ca_ after fc1 =",fc1.shape)
       #  mu = self.dense_sigmoid(fc1)
        # print("ca_ after fc2 =",mu.shape)

      #   U_out = Multiply()([X, mu])

        # print('---channel attention block=',U_out.shape)

        # return U_out
# Improvement: the channel attention block was upgraded to a CBAM-style block
# (channel attention followed by spatial attention).
class Channel_attention(tf.keras.layers.Layer):
    """CBAM-style attention: channel weighting, then spatial weighting.

    Args:
        C: units of the sigmoid Dense layer that produces the channel
           weights. These weights multiply the input, so C presumably equals
           the input's channel count (TODO confirm at call sites).
    """

    def __init__(self, C=64, **kwargs):
        super().__init__(**kwargs)
        self.C = C
        # Channel branch: global average + global max -> Dense(2, relu) -> Dense(C, sigmoid).
        self.gap = GlobalAveragePooling2D()
        self.max_out = GlobalMaxPooling2D()
        self.dense_middle = Dense(units=2, activation='relu')
        self.dense_sigmoid = Dense(units=self.C, activation='sigmoid')
        # Spatial branch: 7x7 conv over per-pixel channel mean/max -> sigmoid.
        self.conv = Conv2D(1, (7, 7), strides=1, padding='same')
        self.conv_sigmoid = Activation(activation='sigmoid')

    def get_config(self):
        config = super().get_config().copy()
        config.update({'C': self.C})
        return config

    def call(self, X):
        # --- channel attention ---
        pooled = Concatenate()([self.gap(X), self.max_out(X)])
        channel_weights = self.dense_sigmoid(self.dense_middle(pooled))
        channel_refined = Multiply()([X, channel_weights])

        # --- spatial attention ---
        # NOTE(review): the spatial map is computed from the block input X,
        # not from channel_refined as in the original CBAM paper. Changing it
        # would alter already-trained models.
        spatial_stats = Concatenate(axis=3)([
            tf.reduce_mean(X, axis=3, keepdims=True),
            tf.reduce_max(X, axis=3, keepdims=True),
        ])
        spatial_weights = self.conv_sigmoid(self.conv(spatial_stats))

        return Multiply()([channel_refined, spatial_weights])


class Avg_pool_Unet_Upsample_msfe(tf.keras.layers.Layer):
    """One branch of the multi-scale feature extraction block.

    The input is average-pooled by ``avg_pool_size`` and passed through a
    four-level U-Net that ends in a 1x1 convolution to 3 channels. The
    result is then bilinearly upsampled by ``upsample_rate``.

    Args:
        avg_pool_size: pool size of the leading AveragePooling2D.
        upsample_rate: factor of the trailing bilinear UpSampling2D.

    NOTE(review): the U-Net's skip concatenations need the pooled height and
    width to be divisible by 16 (four 2x max-pools); confirm input sizes.
    """

    def __init__(self, avg_pool_size, upsample_rate, **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config() can serialise the constructor arguments. It
        # reads these attributes, which were previously never assigned
        # (AttributeError).
        self.avg_pool_size = avg_pool_size
        self.upsample_rate = upsample_rate

        # ---initialization for Avg pooling---
        self.avg_pool = AveragePooling2D(pool_size=avg_pool_size, padding='same')

        # --- initialization for Unet ---
        # Layers are created in the same order as before, and attribute names
        # are unchanged, so previously saved weights should keep matching.
        # Names such as "conv_32_*" label U-Net stages; the actual filter
        # count of each stage is twice the number in its name.
        self.deconv_lst = []
        num_filters = 512
        for _ in range(4):
            # Integer halving gives 256, 128, 64, 32. True division used to
            # pass floats (256.0, ...) as the filter count.
            num_filters //= 2
            self.deconv_lst.append(
                Conv2DTranspose(filters=num_filters, kernel_size=[3, 3], strides=2, padding='same'))

        self.conv_32_down_lst = self._conv_stack(64, 4)
        self.conv_64_down_lst = self._conv_stack(128, 4)
        self.conv_128_down_lst = self._conv_stack(256, 4)
        self.conv_256_down_lst = self._conv_stack(512, 4)
        self.conv_512_down_lst = self._conv_stack(1024, 4)

        self.conv_32_up_lst = self._conv_stack(64, 3)
        self.conv_64_up_lst = self._conv_stack(128, 3)
        self.conv_128_up_lst = self._conv_stack(256, 3)
        self.conv_256_up_lst = self._conv_stack(512, 3)

        self.conv_3 = Conv2D(filters=3, kernel_size=[1, 1])

        self.pooling1_unet = MaxPool2D(pool_size=[2, 2], padding='same')
        self.pooling2_unet = MaxPool2D(pool_size=[2, 2], padding='same')
        self.pooling3_unet = MaxPool2D(pool_size=[2, 2], padding='same')
        self.pooling4_unet = MaxPool2D(pool_size=[2, 2], padding='same')

        # ---initialization for Upsampling---
        self.upsample = UpSampling2D(upsample_rate, interpolation='bilinear')

    @staticmethod
    def _conv_stack(filters, count):
        """Return *count* 3x3 ReLU Conv2D layers with *filters* channels and L2(0.001)."""
        return [Conv2D(filters=filters, kernel_size=[3, 3], activation='relu', padding='same',
                       kernel_regularizer=regularizers.l2(l2=0.001))
                for _ in range(count)]

    @staticmethod
    def _apply_stack(layers, x):
        """Feed *x* through *layers* in order and return the result."""
        for layer in layers:
            x = layer(x)
        return x

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'avg_pool_size': self.avg_pool_size,
            'upsample_rate': self.upsample_rate
        })
        return config

    def upsample_and_concat(self, x1, x2, i):
        """Upsample *x1* with the i-th transposed conv and concatenate skip tensor *x2*."""
        deconv = self.deconv_lst[i](x1)
        return Concatenate()([deconv, x2])

    def unet(self, input):
        """Run the four-level U-Net on *input* and return its 3-channel output."""
        # ---Unet downsampling---: conv stack (kept as skip tensor), then 2x max-pool.
        conv1 = self._apply_stack(self.conv_32_down_lst, input)
        conv2 = self._apply_stack(self.conv_64_down_lst, self.pooling1_unet(conv1))
        conv3 = self._apply_stack(self.conv_128_down_lst, self.pooling2_unet(conv2))
        conv4 = self._apply_stack(self.conv_256_down_lst, self.pooling3_unet(conv3))
        conv5 = self._apply_stack(self.conv_512_down_lst, self.pooling4_unet(conv4))

        # ---Unet upsampling---: 2x transposed conv + skip concat, then conv stack.
        conv6 = self._apply_stack(self.conv_256_up_lst, self.upsample_and_concat(conv5, conv4, 0))
        conv7 = self._apply_stack(self.conv_128_up_lst, self.upsample_and_concat(conv6, conv3, 1))
        conv8 = self._apply_stack(self.conv_64_up_lst, self.upsample_and_concat(conv7, conv2, 2))
        conv9 = self._apply_stack(self.conv_32_up_lst, self.upsample_and_concat(conv8, conv1, 3))

        return self.conv_3(conv9)

    def call(self, X):
        avg_pool = self.avg_pool(X)
        unet = self.unet(avg_pool)
        return self.upsample(unet)


class Multi_scale_feature_extraction(tf.keras.layers.Layer):
    """Concatenate the input with five pooled-U-Net branches (scales 16, 8, 4, 2, 1).

    Each branch average-pools by its scale, runs a U-Net and upsamples by
    the same factor. The output is the channel-wise concatenation
    [X, branch_16, branch_8, branch_4, branch_2, branch_1].
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Explicit attributes (not a list) keep the tracked layer names unchanged.
        self.msfe_16 = Avg_pool_Unet_Upsample_msfe(avg_pool_size=16, upsample_rate=16)
        self.msfe_8 = Avg_pool_Unet_Upsample_msfe(avg_pool_size=8, upsample_rate=8)
        self.msfe_4 = Avg_pool_Unet_Upsample_msfe(avg_pool_size=4, upsample_rate=4)
        self.msfe_2 = Avg_pool_Unet_Upsample_msfe(avg_pool_size=2, upsample_rate=2)
        self.msfe_1 = Avg_pool_Unet_Upsample_msfe(avg_pool_size=1, upsample_rate=1)

    def call(self, X):
        branches = (self.msfe_16, self.msfe_8, self.msfe_4, self.msfe_2, self.msfe_1)
        features = [X] + [branch(X) for branch in branches]
        return Concatenate()(features)


class Kernel_selecting_module(tf.keras.layers.Layer):
    """Selective-kernel fusion of 3x3, 5x5 and 7x7 convolution branches.

    The three branches (each with ``C`` channels) are summed and globally
    pooled. From that descriptor, a per-channel softmax over the three
    branches is computed, and the branches are summed with those weights.

    Args:
        C: number of channels produced by every branch and by the layer.
    """

    def __init__(self, C=21, **kwargs):
        super().__init__(**kwargs)
        self.C = C
        # Three parallel convolutions differing only in kernel size.
        self.c_3 = Conv2D(filters=self.C, kernel_size=(3, 3), strides=1, padding='same',
                          kernel_regularizer=regularizers.l2(0.001))
        self.c_5 = Conv2D(filters=self.C, kernel_size=(5, 5), strides=1, padding='same',
                          kernel_regularizer=regularizers.l2(0.001))
        self.c_7 = Conv2D(filters=self.C, kernel_size=(7, 7), strides=1, padding='same',
                          kernel_regularizer=regularizers.l2(0.001))
        # Attention head: GAP -> Dense(2, relu) -> one Dense(C) per branch.
        self.gap = GlobalAveragePooling2D()
        self.dense_two = Dense(units=2, activation='relu')
        self.dense_c1 = Dense(units=self.C)
        self.dense_c2 = Dense(units=self.C)
        self.dense_c3 = Dense(units=self.C)

    def get_config(self):
        config = super().get_config().copy()
        config.update({'C': self.C})
        return config

    def call(self, X):
        branches = [self.c_3(X), self.c_5(X), self.c_7(X)]
        fused = Add()(branches)

        # Descriptor reshaped to (batch, 1, 1, C), so each Dense(C) yields
        # (batch, 1, 1, C) logits that can be concatenated on axis 1.
        descriptor = tf.reshape(self.gap(fused), [-1, 1, 1, self.C])
        hidden = self.dense_two(descriptor)
        logits = concatenate(
            [self.dense_c1(hidden), self.dense_c2(hidden), self.dense_c3(hidden)], 1)
        # Softmax across the three branches (axis 1), independently per channel.
        branch_weights = softmax(logits, axis=1)

        weighted = []
        for index, branch in enumerate(branches):
            weight = tf.reshape(branch_weights[:, index, :, :], [-1, 1, 1, self.C])
            weighted.append(Multiply()([branch, weight]))
        return Add()(weighted)


def PRIDpredict(img, model):
    """Run a loaded *model* on *img* and return the prediction without the batch axis."""
    # Use keyword arguments: this module has two inference_single_image
    # definitions with opposite positional orders. The positional
    # (model, img) call used before bound the model as the image under the
    # definition in effect.
    return inference_single_image(model=model, noisy_image=img)


class Conv_block(tf.keras.layers.Layer):
    """Three 'same'-padded convolutions, each followed by ReLU.

    Args:
        num_filters: filters per convolution.
        kernel_size: kernel size of every convolution.

    conv_4 and bn_1..bn_4 are constructed but not used in ``call``; they are
    kept so the layer's structure stays unchanged.
    """

    def __init__(self, num_filters=200, kernel_size=3, **kwargs):
        super().__init__(**kwargs)
        self.num_filters = num_filters
        self.kernel_size = kernel_size
        self.conv_1 = Conv2D(filters=self.num_filters, kernel_size=self.kernel_size, padding='same')
        self.conv_2 = Conv2D(filters=self.num_filters, kernel_size=self.kernel_size, padding='same')
        self.conv_3 = Conv2D(filters=self.num_filters, kernel_size=self.kernel_size, padding='same')
        self.conv_4 = Conv2D(filters=self.num_filters, kernel_size=self.kernel_size, padding='same')

        self.bn_1 = BatchNormalization()
        self.bn_2 = BatchNormalization()
        self.bn_3 = BatchNormalization()
        self.bn_4 = BatchNormalization()

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_filters': self.num_filters,
            'kernel_size': self.kernel_size
        })
        return config

    def call(self, X):
        for conv in (self.conv_1, self.conv_2, self.conv_3):
            X = ReLU()(conv(X))
        return X


class DWT_downsampling(tf.keras.layers.Layer):
    """Unnormalised 2-D Haar analysis step used for downsampling.

    The four pixels of every 2x2 block (axes 1 and 2) are combined into the
    sub-bands [LL, LH, HL, HH], which are concatenated on the last axis. The
    output therefore has half the size on axes 1 and 2 and four times the
    channels. Axes 1 and 2 presumably need even lengths so the slices align
    (TODO confirm). Coefficients are plain sums/differences without 1/2
    scaling; IWT_upsampling divides by 4 to invert them.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        """Return concat([LL, LH, HL, HH], axis=-1) for a rank-4 input *x*.

        Formulas adapted from
        https://github.com/AureliePeng/Keras-WaveletTransform/blob/master/models/DWT.py
        """
        top_left = x[:, 0::2, 0::2, :]
        bottom_left = x[:, 1::2, 0::2, :]
        top_right = x[:, 0::2, 1::2, :]
        bottom_right = x[:, 1::2, 1::2, :]

        # Same operand order as the reference formulas, so results are bit-identical.
        sub_bands = [
            top_left + bottom_left + top_right + bottom_right,    # LL
            -top_left - top_right + bottom_left + bottom_right,   # LH
            -top_left + top_right - bottom_left + bottom_right,   # HL
            top_left - top_right - bottom_left + bottom_right,    # HH
        ]
        return Concatenate(axis=-1)(sub_bands)


class IWT_upsampling(tf.keras.layers.Layer):
    """Inverse of DWT_downsampling: rebuilds 2x2 pixel blocks from Haar sub-bands.

    The last axis is split into four equal quarters [LL, LH, HL, HH]. The
    channel count is presumably a multiple of 4 (TODO confirm). The output
    has twice the size on axes 1 and 2 and a quarter of the channels.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        """Reconstruct the upsampled tensor from the sub-bands in *x*.

        Formulas adapted from
        https://github.com/AureliePeng/Keras-WaveletTransform/blob/master/models/DWT.py
        """
        quarter = x.shape[3] // 4
        x_LL = x[:, :, :, 0:quarter]
        x_LH = x[:, :, :, quarter:quarter * 2]
        x_HL = x[:, :, :, quarter * 2:quarter * 3]
        x_HH = x[:, :, :, quarter * 3:]

        # Pixels of each output 2x2 block: (2i, 2j), (2i, 2j+1), (2i+1, 2j), (2i+1, 2j+1).
        top_left = (x_LL - x_LH - x_HL + x_HH) / 4
        top_right = (x_LL - x_LH + x_HL - x_HH) / 4
        bottom_left = (x_LL + x_LH - x_HL - x_HH) / 4
        bottom_right = (x_LL + x_LH + x_HL + x_HH) / 4

        # Interleave the rows within each column parity, then merge the column
        # parities and reshape to (batch, 2H, 2W, C/4).
        left_column = K.stack([top_left, bottom_left], axis=2)
        right_column = K.stack([top_right, bottom_right], axis=2)
        dims = K.shape(x)
        target_shape = K.stack([dims[0], dims[1] * 2, dims[2] * 2, dims[3] // 4])
        return K.reshape(K.concatenate([left_column, right_column], axis=-1), target_shape)


def inference_single_image(noisy_image, model):
    """Run *model* on one image and return the prediction without the batch axis.

    NOTE: this redefines the earlier ``inference_single_image(model,
    noisy_image)``, which had the opposite parameter order. Being later in
    the module, this is the definition in effect. Callers may still pass
    ``(model, image)`` positionally (the earlier order), so the arguments are
    swapped when *model* lacks a ``predict`` method but *noisy_image* has one.

    Args:
        noisy_image: single image array; a leading batch axis of size 1 is added.
        model: object with a Keras-style ``predict(batch)`` method.

    Returns:
        ``model.predict(batch)[0]``.
    """
    if not hasattr(model, 'predict') and hasattr(noisy_image, 'predict'):
        noisy_image, model = model, noisy_image
    input_image = np.expand_dims(noisy_image, axis=0)
    predicted_image = model.predict(input_image)
    return predicted_image[0]


if __name__ == '__main__':
    # Evaluate the saved PRIDNet (CBAM variant) on every readable image in
    # `path`, print per-image and average PSNR, and write the values to a CSV.
    import os
    from skimage.metrics import peak_signal_noise_ratio
    import pandas as pd

    path = 'G:/code/imageData_changedSize_256/imageData_changedSize_256/'
    save_path = 'G:/code/'
    model = tf.keras.models.load_model(
        'G:/code/model/PRIDModelV3/best_PRIDNet_blindnoise_256x256.h5',
        custom_objects={
            'Convolutional_block': Convolutional_block,
            'Channel_attention': Channel_attention,
            'Avg_pool_Unet_Upsample_msfe': Avg_pool_Unet_Upsample_msfe,
            'Multi_scale_feature_extraction': Multi_scale_feature_extraction,
            'Kernel_selecting_module': Kernel_selecting_module,
        })

    psnr_list = []
    # sorted(): os.listdir order is arbitrary; sorting makes CSV rows reproducible.
    for filename in sorted(os.listdir(path)):
        img = cv2.imread(os.path.join(path, filename))
        if img is None:
            # cv2.imread returns None for unreadable / non-image files.
            print('skipping unreadable file: ' + filename)
            continue
        # Keyword arguments bind correctly to either inference_single_image
        # definition in this module (their positional orders differ).
        predict = inference_single_image(noisy_image=img, model=model)
        # NOTE(review): PSNR is measured between the noisy input and the
        # prediction, not against a clean reference. Confirm this is intended.
        value = peak_signal_noise_ratio(img, predict)
        print(str(value) + " " + str(len(psnr_list)))
        psnr_list.append(value)

    if not psnr_list:
        raise SystemExit('no readable images found in ' + path)

    avg_psnr = sum(psnr_list) / len(psnr_list)
    print('平均去噪指标')
    print(avg_psnr)
    pd.DataFrame(psnr_list).to_csv(save_path + 'PRID_CBAM网络去噪去噪.csv', index=None)
